# Load specified SVD and generate peripheral memory maps & structures.
#@author Thomas Roth <thomas.roth@leveldown.de>, Ryan Pavlik <ryan.pavlik@gmail.com>
#@keybinding 
#@menupath 
#@toolbar

# More information:
# https://leveldown.de/blog/svd-loader/
# License: GPLv3

import sys

from cmsis_svd.parser import SVDParser
from ghidra.program.model.data import Structure, StructureDataType, UnsignedIntegerDataType, DataTypeConflictHandler
from ghidra.program.model.data import UnsignedShortDataType, ByteDataType, UnsignedLongLongDataType
from ghidra.program.model.mem import MemoryBlockType
from ghidra.program.model.address import AddressFactory
from ghidra.program.model.symbol import SourceType
from ghidra.program.model.mem import MemoryConflictException

class MemoryRegion(object):
	"""A named address range [start, end) used to merge peripheral ranges.

	``end`` is exclusive: ``length()`` is ``end - start``. After merges a region
	carries several names in ``name_parts``; ``name`` joins them with "_".
	"""

	def __init__(self, name, start, end, name_parts=None):
		"""Create a region.

		:param name: name used when ``name_parts`` is empty or None.
		:param start: first address of the range.
		:param end: address one past the last address of the range.
		:param name_parts: optional list of names. It is copied, so later
			in-place merges (``combine_from``) never mutate the caller's list.
		:raises ValueError: if ``start`` is not strictly less than ``end``.
		"""
		self.start = start
		self.end = end
		# Copy: combine_from() extends this list in place.
		self.name_parts = list(name_parts) if name_parts else [name]

		if not self.start < self.end:
			raise ValueError("invalid region {!r}: start {!r} is not below end {!r}".format(
				self.name, start, end))

	@property
	def name(self):
		"""Underscore-joined names of every region merged into this one."""
		return "_".join(self.name_parts)

	def length(self):
		"""Return ``end - start``."""
		return self.end - self.start

	def __lt__(self, other):
		# Order by start address so list.sort() yields address order.
		return self.start < other.start

	def combine_with(self, other):
		"""Return a new region spanning both ranges; neither input is modified."""
		return MemoryRegion(None,
			min(self.start, other.start),
			max(self.end, other.end),
			self.name_parts + other.name_parts)

	def combine_from(self, other):
		"""Grow this region in place to also span ``other`` and append its names."""
		self.start = min(self.start, other.start)
		self.end = max(self.end, other.end)
		self.name_parts.extend(other.name_parts)

	def overlaps(self, other):
		"""Return True if the two ranges intersect or merely touch.

		Because ``end`` is exclusive, ``other.end == self.start`` means the
		ranges are only adjacent; this check still reports True for that case,
		so adjacent regions are merged as well.
		"""
		return not (other.end < self.start or self.end < other.start)

	def __str__(self):
		return "{}({}:{})".format(self.name, hex(self.start), hex(self.end))

def reduce_memory_regions(regions):
	"""Merge overlapping (or touching) regions into an address-ordered list.

	``regions`` is sorted in place by start address, then swept once: each
	region is folded into the last result entry (via ``combine_from``, which
	mutates that entry) when ``overlaps`` reports True, otherwise appended.

	:param regions: list of MemoryRegion objects; may be empty.
	:returns: new list of merged regions in ascending address order
		(empty when ``regions`` is empty).
	"""
	regions.sort()
	print("Original regions: " + ", ".join(str(x) for x in regions))

	result = []
	for region in regions:
		# "result and ..." handles the first region (and an empty input)
		# without indexing regions[0].
		if result and region.overlaps(result[-1]):
			result[-1].combine_from(region)
		else:
			result.append(region)

	print("Reduced regions: " + ", ".join(str(x) for x in result))
	return result

def calculate_peripheral_size(peripheral, default_register_size):
	"""Return the byte extent covered by a peripheral's registers.

	The result is the largest ``address_offset + width_in_bytes`` over all
	registers, i.e. the end offset of the furthest register; 0 when the
	peripheral has no registers.

	:param peripheral: parsed SVD peripheral exposing ``registers``.
	:param default_register_size: bit width used for registers whose own
		``_size`` is unset or zero.
	:returns: size in bytes as an int.
	"""
	size = 0
	for register in peripheral.registers:
		register_bits = register._size or default_register_size
		# Floor division keeps the result an int on Python 3; on Jython 2 it is
		# identical to the previous "/".
		size = max(size, register.address_offset + register_bits // 8)
	return size


svd_file = askFile("Choose SVD file", "Load SVD File")

print("Loading SVD file...")
parser = SVDParser.for_xml_file(str(svd_file))
print("\tDone!")

# Fetch the device model once and reuse it, instead of calling
# parser.get_device() again for every field read below.
device = parser.get_device()

# Not all SVDs contain a <cpu> element. NOTE(review): assumes cmsis_svd then
# reports cpu as None (or omits it) - confirm against the parser version used.
cpu = getattr(device, "cpu", None)

# CM0, CM4, etc
cpu_type = cpu.name if cpu is not None else None
# little/big
cpu_endian = cpu.endian if cpu is not None else None

default_register_size = device.size

# Not all SVDs contain these fields
if cpu_type and not cpu_type.startswith("CM"):
	print("Currently only Cortex-M CPUs are supported, so this might not work...")
	print("Supplied CPU type was: " + cpu_type)

if cpu_endian and cpu_endian != "little":
	print("Currently only little endian CPUs are supported.")
	print("Supplied CPU endian was: " + cpu_endian)
	sys.exit(1)

# Handles into the current Ghidra program used by the rest of the script.
listing = currentProgram.getListing()
dtm = currentProgram.getDataTypeManager()
symtbl = currentProgram.getSymbolTable()
space = currentProgram.getAddressFactory().getDefaultAddressSpace()

# Reuse an existing global "Peripherals" namespace; create it on first run.
namespace = symtbl.getNamespace("Peripherals", None)
if namespace is None:
	namespace = symtbl.createNameSpace(None, "Peripherals", SourceType.ANALYSIS)

peripherals = parser.get_device().peripherals

print("Generating memory regions...")
# First, build one address range per peripheral. Some SVD files describe
# overlapping peripherals, so the ranges are merged before any Ghidra memory
# block is created.
memory_regions = []
for peripheral in peripherals:
	address_block = getattr(peripheral, "address_block", None)
	if address_block is None:
		# NOTE(review): without an address block there is no extent to map;
		# skip instead of crashing on address_block.offset.
		print("\tSkipping memory region for {}: no address block.".format(peripheral.name))
		continue

	start = peripheral.base_address
	length = address_block.offset + address_block.size
	if length <= 0:
		# MemoryRegion requires start < end.
		print("\tSkipping memory region for {}: empty address block.".format(peripheral.name))
		continue
	end = start + length

	memory_regions.append(MemoryRegion(peripheral.name, start, end))
memory_regions = reduce_memory_regions(memory_regions)

print("Generating memory blocks...")
# Create one uninitialized, volatile, read/write, non-executable memory block
# per merged region.
for r in memory_regions:
	print("\t" + str(r))
	try:
		addr = space.getAddress(r.start)
		length = r.length()

		block = currentProgram.memory.createUninitializedBlock(r.name, addr, length, False)
		block.setRead(True)
		block.setWrite(True)
		block.setExecute(False)
		block.setVolatile(True)
		block.setComment("Generated by SVD-Loader.")
	except MemoryConflictException as e:
		# Use the name imported at file top; "ghidra.program..." is not bound
		# in this script and would raise NameError while matching the except.
		print("\tFailed to generate due to conflict in memory block for: " + r.name)
		# Single-argument print: print("\t", e) prints a tuple on Jython 2.
		print("\t" + str(e))
	except Exception as e:
		print("\tFailed to generate memory block for: " + r.name)
		print("\t" + str(e))

print("\tDone!")

print("Generating peripherals...")
for peripheral in peripherals:
	print("\t" + peripheral.name)

	if len(peripheral.registers) == 0:
		print("\t\tNo registers.")
		continue

	# Size the peripheral from its registers. Most SVDs have an address-block
	# that specifies the size, but it is often far too large, leading to
	# issues with overlaps.
	length = calculate_peripheral_size(peripheral, default_register_size)

	# Generate structure for the peripheral
	peripheral_struct = StructureDataType(peripheral.name, length)

	peripheral_start = peripheral.base_address
	peripheral_end = peripheral_start + length
	print("\t\t{}:{}".format(hex(peripheral_start), hex(peripheral_end)))

	for register in peripheral.registers:
		register_size = default_register_size if not register._size else register._size
		# Register width in bytes. Floor division keeps it an int on Python 3
		# (identical to "/" on Jython 2), so hex() and Ghidra calls accept it.
		rs = register_size // 8

		# 1, 2 and 8 byte registers get a matching type; every other width
		# falls back to a 4-byte unsigned int.
		if rs == 1:
			r_type = ByteDataType()
		elif rs == 2:
			r_type = UnsignedShortDataType()
		elif rs == 8:
			r_type = UnsignedLongLongDataType()
		else:
			r_type = UnsignedIntegerDataType()

		print("\t\t\t{}({}:{})".format(register.name, hex(register.address_offset), hex(register.address_offset + rs)))
		try:
			peripheral_struct.replaceAtOffset(register.address_offset, r_type, rs, register.name, register.description)
		except:
			# Bare except on purpose so Java exceptions raised by Ghidra are
			# caught too (presumably overlapping or out-of-range registers -
			# confirm). Skip only this register instead of aborting the script.
			print("\t\t\tFailed to add register {}: {}".format(register.name, sys.exc_info()[1]))

	addr = space.getAddress(peripheral_start)

	dtm.addDataType(peripheral_struct, DataTypeConflictHandler.REPLACE_HANDLER)
	symtbl.createLabel(addr,
					peripheral.name,
					namespace,
					SourceType.USER_DEFINED)
	try:
		# Two-argument form: the data length comes from the structure itself
		# (the old call passed False as an int length).
		listing.createData(addr, peripheral_struct)
	except:
		# Bare except kept so Java exceptions from Ghidra (e.g. conflicting
		# code units at addr) are caught as well; report the cause instead of
		# silently dropping it.
		print("\t\tFailed to generate peripheral " + peripheral.name)
		print("\t\t" + str(sys.exc_info()[1]))